package com.opencee.cloud.msg.websocket;

import lombok.extern.slf4j.Slf4j;
import org.springframework.stereotype.Component;
import org.springframework.util.StringUtils;

import javax.websocket.*;
import javax.websocket.server.PathParam;
import javax.websocket.server.ServerEndpoint;
import java.util.HashMap;
import java.util.Map;
import java.util.Set;
import java.util.concurrent.ConcurrentHashMap;

/**
 * webscocket端点
 *
 * @author
 */
@Slf4j
@ServerEndpoint(value = "/ws/{userId}")
@Component
public class WebSocketEndpoint {

    private static ConcurrentHashMap<String, Map<String, Session>> webSocketMap = new ConcurrentHashMap<String, Map<String, Session>>();

    /**
     * 连接建立成功调用的方法
     */
    @OnOpen
    public void onOpen(@PathParam("userId") String userId, Session session) {
        Map<String, Session> sessionMap = null;
        if (webSocketMap.containsKey(userId)) {
            sessionMap = webSocketMap.get(userId);
            sessionMap.put(session.getId(), session);
            webSocketMap.put(userId, sessionMap);
        } else {
            sessionMap = new HashMap<>();
            sessionMap.put(session.getId(), session);
            webSocketMap.put(userId, sessionMap);
        }
        if (log.isDebugEnabled()) {
            log.debug("用户连接:{},{}", userId, sessionMap.size());
        }
    }

    /**
     * 连接关闭调用的方法
     */
    @OnClose
    public void onClose(@PathParam("userId") String userId, Session session) {
        int count = 0;
        if (webSocketMap.containsKey(userId)) {
            Map<String, Session> sessionMap = webSocketMap.get(userId);
            sessionMap.remove(session.getId());
            if (sessionMap != null && sessionMap.size() == 0) {
                webSocketMap.remove(userId);
            }
            count = sessionMap.size();
        }
        if (log.isDebugEnabled()) {
            log.debug("用户退出:{},{}", userId, count);
        }
    }

    /**
     * 收到客户端消息后调用的方法
     *
     * @param message 客户端发送过来的消息
     */
    @OnMessage
    public void onMessage(@PathParam("userId") String userId, String message, Session session) {
        if (log.isDebugEnabled()) {
            log.debug("用户消息:{},报文:{}", userId, message);
        }
    }

    /**
     * @param session
     * @param error
     */
    @OnError
    public void onError(@PathParam("userId") String userId, Session session, Throwable error) {
        log.error("用户错误:", error);
    }

    /**
     * 发送自定义消息
     */
    public synchronized static void sendMessage(String message, String userId) {
        if (log.isDebugEnabled()) {
            log.debug("发送消息到:{}，报文:{}", userId, message);
        }
        try {
            if (!StringUtils.isEmpty(userId) && webSocketMap.containsKey(userId)) {
                Map<String, Session> sessionMap = webSocketMap.get(userId);
                if (sessionMap != null && sessionMap.size() > 0) {
                    Set<String> keySet = sessionMap.keySet();
                    for (String key : keySet) {
                        Session session = sessionMap.get(key);
                        if (session != null && session.isOpen()) {
                            session.getBasicRemote().sendText(message);
                        }
                    }
                }
            } else {
                if (log.isDebugEnabled()) {
                    log.warn("用户{},不在线！", userId);
                }
            }
        } catch (Exception e) {
            log.error("发送消息错误:{}，报文:{}", userId, message, e);
        }
    }

}
